#!/bin/bash
# Launch LoRA DreamBooth fine-tuning of FLUX.1-Kontext-dev on the
# KontextRefControl "puthere" dataset.
#
# Usage:
#   bash train_flux_kontext/train_flux_kontext_puthere_dreambooth.sh GPU_IDS LORA_RANK
#
#   GPU_IDS    value exported as CUDA_VISIBLE_DEVICES (e.g. "2" or "0,1")
#   LORA_RANK  LoRA rank; also passed as --lora_alpha and embedded in the
#              output directory name
#
# Example (background run with log file):
#   nohup bash train_flux_kontext/train_flux_kontext_puthere_dreambooth.sh 2 4 > train_flux_kontext/train_flux_kontext_puthere_dreambooth_rank4.log 2>&1 &
set -euo pipefail

# Fail fast on missing/empty arguments: an empty CUDA_VISIBLE_DEVICES hides
# every GPU, and an empty rank would turn "--rank --lora_alpha" into bare flags.
if (( $# < 2 )) || [[ -z "$1" || -z "$2" ]]; then
  printf 'Usage: %s GPU_IDS LORA_RANK\n' "${0##*/}" >&2
  exit 2
fi

export CUDA_VISIBLE_DEVICES=$1
rank=$2

if [[ ! "$rank" =~ ^[1-9][0-9]*$ ]]; then
  printf 'error: LORA_RANK must be a positive integer, got %q\n' "$rank" >&2
  exit 2
fi

dataset_dir=/mnt/nas/shengjie/datasets/KontextRefControl_puthere

# Arguments collected in an array so every value stays a single word and the
# invocation does not depend on fragile backslash line continuations.
# NOTE(review): --instance_data_dir points at a metadata.jsonl file rather than
# a directory; presumably the training script accepts a jsonl path — confirm.
# NOTE(review): both --num_train_epochs and --max_train_steps are set;
# presumably max_train_steps takes precedence (as in the diffusers DreamBooth
# scripts) — confirm against the training script.
train_args=(
  --pretrained_model_name_or_path "/data/models/FLUX.1-Kontext-dev"
  --cache_dir /mnt/nas/shengjie/cache
  --instance_data_dir "${dataset_dir}/train/metadata.jsonl"
  --image_column "file_name"
  --cond_image_column "control_image"
  --caption_column "prompt"
  --aspect_ratio_buckets 1024,1024
  --train_batch_size 2
  --num_train_epochs 80
  --max_train_steps 8000
  --gradient_accumulation_steps 2
  --guidance_scale 2.5
  --learning_rate 1e-4
  --lr_scheduler "cosine"
  --rank "$rank"
  --lora_alpha "$rank"
  --lora_dropout 0
  --output_dir "/mnt/nas/shengjie/puthere_output_${rank}_0922"
  --checkpointing_steps 500
  --mixed_precision "bf16"
  --gradient_checkpointing
  --use_8bit_adam
  --cache_latents
  --seed 42
)

python train_flux_kontext/train_flux_kontext_dreamboot_general.py "${train_args[@]}"